from Model.model_utils import init_passive, init_active_inter_all, init_single_passive, init_pair, init_active_inter_full
from Model.InferenceModule.passive_module import PassiveModule
from Model.InferenceModule.passive_single_module import SinglePassiveModule
from Model.InferenceModule.full_module import FullModule
from Model.InferenceModule.all_module import AllModule
from Model.InferenceModule.pairwise_module import PairModule
from Model.InferenceModule.module_utils import init_dists
from ACState.extractor import get_factor_params
from tianshou.data import Batch
from Network.network_utils import initialize_optimizer
import numpy as np

class InferenceModel():
    '''
    Container for every interaction-inference sub-model, plus the module wrappers
    that run them and expose their (model, optimizer) pairs.

    Sub-models (built by the Model.model_utils init helpers):
        all_passive_model, all_model, all_inter_model: single models shared by all names
        single_passive_models: per-name passive models, keyed by train name
        pair_models: per-pair models, keyed by "source->target" pair name
        full_models, inter_models: per-name full and interaction models, keyed by train name
    Each is wrapped in a module (PassiveModule, SinglePassiveModule, PairModule,
    FullModule, AllModule) that performs the forward computation in infer().
    '''

    @staticmethod
    def _get_or_none(models, key):
        '''returns models[key] if key is present in models, otherwise None'''
        return models[key] if key in models else None

    def __init__(self, args, extractor, normalizer, environment):
        '''
        stores the networks and their corresponding optimizers
        this model is built out of a collection of submodels,
        which are initialized in this init.
        it also houses the extractors, but these are passed in

        @param args configuration namespace; reads args.inter, args.factor, args.interaction_net,
            args.environment and args.torch
        @param extractor object exposing .names, .num_objects and .get_index
        @param normalizer stored and passed to every module call in infer
        @param environment forwarded to regenerate
        '''
        self.mp = args.inter
        self.fp = args.factor

        args.interaction_net.factor = args.factor
        self.all_passive_model = init_passive(args)
        self.all_model, self.all_inter_model = init_active_inter_all(args)
        #TODO: first_obj_dim is different for different objects, store in factor and access appropriately
        self.single_passive_models = init_single_passive(args)
        self.pair_names = args.inter.pair_names
        self.pair_models = init_pair(args, extractor)
        self.full_models, self.inter_models = init_active_inter_full(args)
        # without an all-model only the trained names are traced, otherwise every extractor name
        self.trace_idx = np.array([extractor.names.index(n) for n in args.inter.train_names]) if self.all_model is None else np.arange(len(extractor.names))
        self.train_names = args.inter.train_names

        self.name = "_".join([args.environment.env, args.environment.variant, *args.inter.train_names])
        self.dists = init_dists(args)

        get = self._get_or_none

        # the passive module computes the passive output for ALL factors
        self.passive_module = PassiveModule(args, extractor, self.dists.forward, self.all_passive_model, self.all_model)

        # uses a different model for each passive output in args.inter.train_names
        self.single_passive_modules = dict()
        for sp in self.mp.train_names:
            self.single_passive_modules[sp] = SinglePassiveModule(
                args, extractor, sp, self.dists.forward,
                get(self.single_passive_models, sp), get(self.full_models, sp),
                self.all_passive_model, self.all_model)

        # uses a different model for each pairwise output in args.inter.pair_names;
        # the full model handed to a pair is the one of its target (the part after '->')
        self.pair_modules = dict()
        for pair in self.mp.pair_names:
            target = pair.split('->')[-1]
            self.pair_modules[pair] = PairModule(
                args, extractor, pair, self.dists.forward,
                get(self.pair_models, pair), get(self.full_models, target), self.all_model)

        # uses a different model for each full output in args.inter.train_names;
        # models are looked up by this module's own `name`
        self.full_modules = dict()
        for name in self.mp.train_names:
            self.full_modules[name] = FullModule(
                args, extractor, name, self.dists,
                get(self.inter_models, name), self.all_inter_model,
                get(self.full_models, name), self.all_model)

        # uses a single model for all values
        self.all_module = AllModule(args, extractor, self.dists, self.all_inter_model, self.all_model)
        self.regenerate(extractor, normalizer, environment)

        # set up cuda (TODO: try to limit as much of the torch dependency as possible)
        if args.torch.cuda:
            self.cuda(device=args.torch.gpu)
        else:
            self.cpu()
        self.device = args.torch.gpu
        print(self.full_modules)

    def _iter_modules(self):
        '''yields every module wrapper in the order: passive, single passive, pair, full, all'''
        yield self.passive_module
        for sp in self.mp.train_names:
            yield self.single_passive_modules[sp]
        for pair in self.mp.pair_names:
            yield self.pair_modules[pair]
        for name in self.mp.train_names:
            yield self.full_modules[name]
        yield self.all_module

    def cuda(self, device=None):
        '''moves every module to cuda; self.device is only updated when device is given'''
        self.iscuda = True
        if device is not None: self.device = device
        for module in self._iter_modules():
            module.cuda(device=device)
        return self

    def cpu(self):
        '''moves every module to the cpu'''
        self.iscuda = False
        for module in self._iter_modules():
            module.cpu()
        return self

    def set_target_name(self, target_name):
        '''
        selects which per-name / per-pair modules infer, get_module and get_model_optim use.
        "source->target" sets pair_name to the full string and target_name to the part after
        the last '->'; any other name sets target_name and pair_name = "UNUSED"
        '''
        if '->' in target_name:
            self.target_name = target_name.split('->')[-1]
            self.pair_name = target_name
        else:
            self.target_name = target_name
            self.pair_name = "UNUSED"

    def regenerate(self, extractor, normalizer, environment):
        '''
        replaces the extractor (on this object and every module) and the normalizer,
        then recomputes factor_params. environment is not used here
        '''
        self.extractor = extractor
        self.normalizer = normalizer
        for module in self._iter_modules():
            module.extractor = extractor
        self.factor_params = get_factor_params(extractor)

    def passive_mask(self, batch, name, form):
        '''
        returns the passive mask for a given batch size and form (all or full)
        @param batch batch size (number of leading rows)
        @param name object name whose indices are set to 1 when form == "full"
        @param form "full": shape (batch, num_name_instances, num_objects), every row has ones
            at the indices of `name`; anything else: shape (batch, num_objects, num_objects),
            an identity matrix per batch element
        the result comes from np.broadcast_to, so it is a read-only view; copy before writing
        '''
        num_objects = self.extractor.num_objects
        if form == "full":
            idxes = self.extractor.get_index([name])
            mask = np.zeros((batch, num_objects))
            mask[:, idxes] = 1
            return np.broadcast_to(np.expand_dims(mask, axis=1), (batch, len(idxes), num_objects))
        return np.broadcast_to(np.expand_dims(np.eye(num_objects), axis=0), (batch, num_objects, num_objects))

    def infer(self, batch, given_mask, infer_type, additional=None, grad_settings=None, log_batch=None, keep_invalid=False, keep_all=False):
        '''
        takes in a batch and converts to pytorch
            Batch needs keys:
                obs
                target
                next_obs
                next_target
        computes all of the necessary components, including internal masks
        type should be one of the initialized submodels
        supported inference types:
            passive
            single_passive
            pair
            full
            mask (masked forward)
            probs (returns just the mask)
            all
            all_mask
            all_probs
        additional parameters: (will result in additional values in result.module_type, if applicable)
            attention
            gradient # TODO: since this requires an optimization step, probably not immediately usable
            hard
            soft
            mixed
            flat
        names computes only those particular names, for name-specific models
        log_batch: keeps in result certain values from the batch, usually for logging. Typically:
            trace
        @param keep_invalid will not filter out outcomes which are not valid (for full, pair, mask or single_passive)
        @param keep_all will keep done and invalid values in the output
        @param given_mask is valid * the given mask, for given inference type (otherwise just valid)
        @param additional, grad_settings, log_batch default to a fresh empty list on every call
        returns a tianshou Batch with one key per requested inference type
        '''
        # fresh lists per call (a mutable default would be shared across calls)
        additional = [] if additional is None else additional
        grad_settings = [] if grad_settings is None else grad_settings
        log_batch = [] if log_batch is None else log_batch
        shared = dict(additional=additional, grad_settings=grad_settings, log_batch=log_batch)
        args = (batch, given_mask, self.extractor, self.normalizer)

        # NOTE(review): selection uses `in`; when infer_type is a single string this is a
        # substring test, so e.g. "single_passive" also runs "passive" and "all_mask" also
        # runs "all" and "mask". Passing a list of type names selects exactly.
        result = Batch()
        if 'passive' in infer_type:
            result.passive = self.passive_module(*args, keep_all=keep_all, **shared)
        if 'single_passive' in infer_type:
            result.single_passive = self.single_passive_modules[self.target_name](*args, keep_invalid=keep_invalid, keep_all=keep_all, **shared)
        if 'pair' in infer_type:
            result.pair = self.pair_modules[self.pair_name](*args, keep_invalid=keep_invalid, keep_all=keep_all, **shared)
        if 'full' in infer_type: # uses the full model with the ones * given mask
            result.full = self.full_modules[self.target_name](*args, full=True, keep_invalid=keep_invalid, keep_all=keep_all, **shared)
        if 'mask' in infer_type: # uses the full model but infers the mask, see additional parameters for what kind of mask is used
            result.mask = self.full_modules[self.target_name](*args, full=False, keep_invalid=keep_invalid, keep_all=keep_all, **shared)
        if 'probs' in infer_type: # uses the full model but returns only the inferred mask probabilities
            result.probs = self.full_modules[self.target_name](*args, full=False, keep_invalid=keep_invalid, keep_all=keep_all, probs=True, **shared)
        if 'all' in infer_type:
            result.all = self.all_module(*args, full=True, keep_all=keep_all, **shared)
        if 'all_mask' in infer_type:
            result.all_mask = self.all_module(*args, full=False, keep_all=keep_all, **shared)
        if 'all_probs' in infer_type: # uses the all model but returns only the inferred mask probabilities
            result.all_probs = self.all_module(*args, probs=True, keep_all=keep_all, **shared)
        return result

    def assign_module_from_model(self):
        '''
        after loading, reassigns each module's models from the (possibly replaced) model
        containers; a name missing from a container is assigned None, matching __init__
        '''
        get = self._get_or_none
        for sp in self.mp.train_names:
            self.single_passive_modules[sp].assign_from_model(get(self.single_passive_models, sp))
        for pair in self.mp.pair_names:
            self.pair_modules[pair].assign_from_model(get(self.pair_models, pair))
        for name in self.mp.train_names:
            self.full_modules[name].assign_from_model(get(self.full_models, name), get(self.inter_models, name))
        self.all_module.assign_from_model(self.all_model, self.all_inter_model)
        self.passive_module.assign_from_model(self.all_passive_model)

    def get_module(self, infer_type):
        '''returns the module that infer would use for a single infer_type, or None if unknown'''
        if infer_type == 'passive': return self.passive_module
        if infer_type == 'single_passive': return self.single_passive_modules[self.target_name]
        if infer_type == 'pair': return self.pair_modules[self.pair_name]
        if infer_type in ('full', 'mask', 'probs'): return self.full_modules[self.target_name]
        if infer_type in ('all', 'all_mask', 'all_probs'): return self.all_module
        return None

    def _model_optim_source(self, compute_type):
        '''maps a compute type to (module, kwargs for return_model_optimizer), or None if unknown'''
        if compute_type == 'passive': return self.passive_module, {}
        if compute_type == 'single_passive': return self.single_passive_modules[self.target_name], {}
        if compute_type == 'pair': return self.pair_modules[self.pair_name], {}
        if compute_type in ('full', 'full_both', 'full_inter'):
            module = self.full_modules[self.target_name]
        elif compute_type in ('all', 'all_both', 'all_inter'):
            module = self.all_module
        else:
            return None
        if compute_type.endswith('_both'): return module, {'both': True}
        if compute_type.endswith('_inter'): return module, {'inter': True}
        return module, {}

    def get_model_optim(self, compute_types):
        '''
        returns the appropriate model and optimizer
        for a particular compute type
        returns the same order as input; unknown compute types are skipped
        supported: passive, single_passive, pair, full, all, full_both, all_both, full_inter, all_inter
        '''
        models = list()
        optimizers = list()
        for n in compute_types:
            source = self._model_optim_source(n)
            if source is None: continue
            module, kwargs = source
            model, optimizer = module.return_model_optimizer(**kwargs)
            models.append(model)
            optimizers.append(optimizer)
        return models, optimizers